# MIT License

# Copyright (c) 2024 dechin

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

# Build the extension in place with: cythonize -i -f wrapper.pyx
import os
import site
from pathlib import Path
cimport cython
from cython.parallel import prange
from libc.math cimport log

# Directories that may hold the shared libraries shipped with CyFES: the
# system and per-user site-packages directories, their 'CyFES.libs'
# sub-directories, and the directory three levels above each site-packages.
site_path_ = site.getsitepackages()[0]
lib_path = str(Path(site_path_) / 'CyFES.libs')
# site.USER_SITE may still be None if the site module has not computed it yet;
# getusersitepackages() computes and returns it in that case.
user_site_path_ = site.USER_SITE or site.getusersitepackages()
user_lib_path = str(Path(user_site_path_) / 'CyFES.libs')
site_path = str(Path(site_path_).parent.parent.parent)
user_site_path = str(Path(user_site_path_).parent.parent.parent)

# Append every candidate directory to LD_LIBRARY_PATH, keeping whatever was
# already there in front. The variable is read with .get() because it is often
# unset, and indexing os.environ directly would then raise KeyError on import.
# NOTE(review): the dynamic loader typically reads LD_LIBRARY_PATH at process
# start, so this probably only affects child processes -- confirm.
_existing_ld_path = os.environ.get('LD_LIBRARY_PATH', '')
_extra_ld_dirs = [
    site_path_, site_path, lib_path,
    user_site_path_, user_site_path, user_lib_path,
]
os.environ['LD_LIBRARY_PATH'] = os.pathsep.join(
    ([_existing_ld_path] if _existing_ld_path else []) + _extra_ld_dirs
)

import numpy as np
cimport numpy as np

# Module-level numerical constants used by the free-energy routines.
cdef:
    double PI = 3.14159265359        # pi
    double SQRT2PI = 2.506628274631  # sqrt(2 * pi)
    double SQRT2PI3 = 15.74961       # (2 * pi) ** 1.5, i.e. SQRT2PI cubed
    double kT = 0.596128107          # thermal energy; presumably kcal/mol near 300 K -- TODO confirm

# Subset of the POSIX dynamic-loading API from <dlfcn.h>. It is used below to
# load libcufes.so at import time and to look up its exported functions by
# name. All declarations are callable without the GIL.
cdef extern from "<dlfcn.h>" nogil:
    # Load the shared object at the given path; returns an opaque handle.
    void *dlopen(const char *, int)
    # Human-readable description of the most recent dl* failure.
    char *dlerror()
    # Address of the named symbol inside the library referred to by the handle.
    void *dlsym(void *, const char *)
    # Release a handle obtained from dlopen.
    int dlclose(void *)
    enum:
        # dlopen flag: resolve function symbols lazily, on first use.
        RTLD_LAZY

# Local stand-in for the CUDA stream type: cudaStream_t is a pointer to a
# CUstream_st. Only the pointer value is passed around in this file.
ctypedef struct CUstream_st: 
    void* ptr
ctypedef CUstream_st* cudaStream_t 

# One coordinate triple: three consecutive doubles (x, y, z).
ctypedef struct CRD:
    double x, y, z

# One path / collective-variable point; it wraps a single CRD, so a PATH array
# has the same memory layout as an (N, 3) array of doubles.
ctypedef struct PATH:
    CRD crds

# Signatures of the functions resolved from libcufes.so with dlsym. All of them
# are declared noexcept and callable without the GIL.
# Used for the "GetWeight" symbol: (length, bias, shift, weight_out).
ctypedef int (*GetWeightFunc)(int, double*, double, double*) noexcept nogil
# Used for 'GetDist' and 'GaussGetDist': (cv_length, crd, cv, dis_out).
ctypedef int (*GetDistFunc)(int, CRD*, PATH*, double*) noexcept nogil
# Used for 'GaussGetDistHeight': (cv_length, crd, cv, dis_out, height).
ctypedef int (*GaussDistFunc)(int, CRD*, PATH*, double*, double*) noexcept nogil
# Used for 'StickCv': takes a host cv array and a device id, returns a PATH*
# that is later handed to ReleaseCv.
ctypedef PATH* (*StickCv)(int CV_LENGTH, PATH* cv, int device_id) noexcept nogil
# Used for 'ReleaseCv': releases a pointer previously returned by StickCv.
ctypedef int (*ReleaseCv)(PATH*, int device_id) noexcept nogil
# Same signature as GaussDistFunc; not referenced elsewhere in this file.
ctypedef int (*FastGaussDistFunc)(int, CRD*, PATH*, double*, double*) noexcept nogil
# Used for 'cudaStreamCreate', 'cudaStreamDestroy' and 'cudaStreamSynchronize'.
ctypedef void (*StreamCreate)(cudaStream_t *) noexcept nogil
ctypedef void (*StreamDestroy)(cudaStream_t) noexcept nogil
ctypedef void (*StreamSynchronize)(cudaStream_t) noexcept nogil
# Used for 'GaussGetDistHeightStream': GaussDistFunc plus a stream argument.
ctypedef int (*StreamGaussDistFunc)(int, CRD*, PATH*, double*, double*, cudaStream_t) noexcept nogil
# Used for 'GaussHeightDevice': GaussDistFunc plus a device id and a stream.
ctypedef int (*DeviceGaussFunc)(int, CRD*, PATH*, double*, double*, int, cudaStream_t) noexcept nogil

# Locate libcufes.so. The candidates are probed in this order and the first
# existing file wins.
_cufes_candidates = (
    site_path + "/cyfes/libcufes.so",
    site_path_ + "/cyfes/libcufes.so",
    user_site_path + "/cyfes/libcufes.so",
    user_site_path_ + "/cyfes/libcufes.so",
)
cufes_path = None
for _cufes_candidate in _cufes_candidates:
    if os.path.exists(_cufes_candidate):
        cufes_path = _cufes_candidate
        break
if cufes_path is None:
    raise ValueError('No dynamic link file libcufes found! Searched: '
                     + ', '.join(_cufes_candidates))

# Keep the encoded path in a named object so the char* handed to dlopen stays
# valid for the duration of the call.
cdef bytes _cufes_path_bytes = cufes_path.encode('utf-8')
cdef char* _dl_error = NULL
cdef void* handle = dlopen(_cufes_path_bytes, RTLD_LAZY)
# dlopen returns NULL on failure. Without this check the failure would only
# show up later as a crash when a function pointer resolved through the
# invalid handle is called.
if handle == NULL:
    _dl_error = dlerror()
    _dl_reason = _dl_error.decode('utf-8', 'replace') if _dl_error != NULL else 'unknown error'
    raise OSError('Failed to load ' + cufes_path + ': ' + _dl_reason)

@cython.boundscheck(False)
@cython.wraparound(False)
cpdef double[:] get_weight(double[:] bias):
    """Compute per-point weights from bias values via libcufes ``GetWeight``.

    Parameters
    ----------
    bias : double[:]
        Bias values. A raw pointer to the first element is passed to the C
        library, so the buffer must be contiguous.

    Returns
    -------
    double[:]
        Newly allocated array with one weight per bias value. For an empty
        ``bias`` an empty array is returned without calling the library.

    Raises
    ------
    RuntimeError
        If the ``GetWeight`` symbol cannot be resolved in the loaded library.
    """
    cdef:
        GetWeightFunc GetWeight
        int CV_LENGTH = bias.shape[0]
        double shift = 0.0
        double[:] weight = np.zeros((CV_LENGTH, ))
    # Nothing to compute; also avoids taking the address of a missing element.
    if CV_LENGTH == 0:
        return weight
    GetWeight = <GetWeightFunc>dlsym(handle, "GetWeight")
    # dlsym returns NULL for an unknown symbol; calling it would crash.
    if GetWeight == NULL:
        raise RuntimeError("Symbol 'GetWeight' not found in libcufes")
    # The integer status returned by the library is not inspected because its
    # convention is not visible from this file.
    GetWeight(CV_LENGTH, &bias[0], shift, &weight[0])
    return weight

@cython.boundscheck(False)
@cython.wraparound(False)
cpdef int get_weight_nogil(double[:] bias, double[:] weight) nogil:
    """GIL-free variant of ``get_weight`` that writes into a caller buffer.

    Parameters
    ----------
    bias : double[:]
        Bias values; must be contiguous because a raw pointer is passed on.
    weight : double[:]
        Output buffer, at least as long as ``bias`` and contiguous. It
        receives one weight per bias value.

    Returns
    -------
    int
        The status value returned by the library's ``GetWeight`` function.

    Raises
    ------
    ValueError
        If ``weight`` is shorter than ``bias``.
    RuntimeError
        If the ``GetWeight`` symbol cannot be resolved in the loaded library.
    """
    cdef:
        GetWeightFunc GetWeight
        int CV_LENGTH = bias.shape[0]
        double shift = 0.0
    # Bounds checking is disabled, so a short output buffer would be overrun
    # by the C routine. Reject it explicitly instead.
    if weight.shape[0] < CV_LENGTH:
        with gil:
            raise ValueError('weight must be at least as long as bias')
    GetWeight = <GetWeightFunc>dlsym(handle, "GetWeight")
    # dlsym returns NULL for an unknown symbol; calling it would crash.
    if GetWeight == NULL:
        with gil:
            raise RuntimeError("Symbol 'GetWeight' not found in libcufes")
    return GetWeight(CV_LENGTH, &bias[0], shift, &weight[0])

@cython.boundscheck(False)
@cython.wraparound(False)
cpdef double[:] get_dis(double[:] crd, double[:, :] cv):
    """Evaluate libcufes ``GetDist`` for one coordinate against every cv point.

    Parameters
    ----------
    crd : double[:]
        One coordinate; its first three doubles are reinterpreted as a CRD.
    cv : double[:, :]
        Points of shape (N, 3), reinterpreted in place as an array of N PATH
        structs. Both inputs must be C-contiguous because raw pointers to
        their first elements are handed to the library.

    Returns
    -------
    double[:]
        Newly allocated array of length N filled by ``GetDist``. For an empty
        ``cv`` an empty array is returned without calling the library.

    Raises
    ------
    ValueError
        If ``crd`` has fewer than 3 values or ``cv`` is not (N, 3).
    RuntimeError
        If the ``GetDist`` symbol cannot be resolved in the loaded library.
    """
    cdef:
        GetDistFunc GetDist
        int CV_LENGTH = cv.shape[0]
        double[:] dis = np.zeros((CV_LENGTH, ))
    if CV_LENGTH == 0:
        return dis
    # The pointer casts below read 3 doubles per CRD / PATH with bounds checks
    # disabled, so wrong shapes would read unrelated memory.
    if crd.shape[0] < 3:
        raise ValueError('crd must contain 3 coordinates (x, y, z)')
    if cv.shape[1] != 3:
        raise ValueError('cv must have shape (N, 3)')
    GetDist = <GetDistFunc>dlsym(handle, 'GetDist')
    # dlsym returns NULL for an unknown symbol; calling it would crash.
    if GetDist == NULL:
        raise RuntimeError("Symbol 'GetDist' not found in libcufes")
    # The returned status is ignored; its convention is not visible here.
    GetDist(CV_LENGTH, <CRD*>&crd[0], <PATH*>&cv[0, 0], &dis[0])
    return dis

@cython.boundscheck(False)
@cython.wraparound(False)
cpdef double[:, :] batch_get_dis(double[:, :] crd, double[:, :] cv):
    """Evaluate libcufes ``GaussGetDist`` for every coordinate row of ``crd``.

    Parameters
    ----------
    crd : double[:, :]
        Coordinates, one per row; the first three doubles of each row are
        reinterpreted as a CRD.
    cv : double[:, :]
        Points of shape (N, 3), reinterpreted in place as N PATH structs.
        Both inputs must be C-contiguous because raw pointers are passed on.

    Returns
    -------
    double[:, :]
        Newly allocated (len(crd), N) array; row ``i`` is filled by the
        library call made for ``crd[i]``. If either input is empty the
        zero-initialised array is returned without calling the library.

    Raises
    ------
    ValueError
        If ``crd`` rows hold fewer than 3 values or ``cv`` is not (N, 3).
    RuntimeError
        If the ``GaussGetDist`` symbol cannot be resolved.
    """
    cdef:
        GetDistFunc GaussGetDist
        int i
        int CV_LENGTH = cv.shape[0]
        int CRD_LENGTH = crd.shape[0]
        double[:, :] dis = np.zeros((CRD_LENGTH, CV_LENGTH))
    if CRD_LENGTH == 0 or CV_LENGTH == 0:
        return dis
    # The pointer casts below read 3 doubles per CRD / PATH with bounds checks
    # disabled, so wrong shapes would read unrelated memory.
    if crd.shape[1] < 3:
        raise ValueError('crd must provide 3 coordinates per row')
    if cv.shape[1] != 3:
        raise ValueError('cv must have shape (N, 3)')
    GaussGetDist = <GetDistFunc>dlsym(handle, 'GaussGetDist')
    # dlsym returns NULL for an unknown symbol; calling it would crash.
    if GaussGetDist == NULL:
        raise RuntimeError("Symbol 'GaussGetDist' not found in libcufes")
    for i in range(CRD_LENGTH):
        # The returned status is ignored; its convention is not visible here.
        GaussGetDist(CV_LENGTH, <CRD*>&crd[i, 0], <PATH*>&cv[0, 0], &dis[i, 0])
    return dis

@cython.boundscheck(False)
@cython.wraparound(False)
cpdef double[:] PathFES(double[:, :] crd, double[:, :] cv, double[:] bw, double[:] bias):
    """Free-energy estimate at each row of ``crd`` via ``GaussGetDistHeight``.

    For every coordinate the library fills one row of per-cv-point values;
    the result is ``-kT * log(mean(row))``, shifted so its minimum is zero.

    Parameters
    ----------
    crd : double[:, :]
        Evaluation coordinates, at least 3 doubles per row.
    cv : double[:, :]
        Points of shape (N, 3), reinterpreted in place as N PATH structs.
    bw : double[:]
        Three bandwidths; their product scales the kernel volume.
    bias : double[:]
        Bias per cv point (at least N values), converted by ``get_weight``.
    All arrays must be C-contiguous because raw pointers are passed on.

    Returns
    -------
    double[:]
        Newly allocated array with one value per ``crd`` row (empty if
        ``crd`` is empty).

    Raises
    ------
    ValueError
        If ``cv`` is empty or not (N, 3), ``crd`` rows are too short,
        ``bw`` has fewer than 3 values or ``bias`` has fewer than N values.
    RuntimeError
        If the ``GaussGetDistHeight`` symbol cannot be resolved.
    """
    cdef:
        double volume, res_min, total
        double[:] weight
        GaussDistFunc GaussGetDist
        int i, j
        int CV_LENGTH = cv.shape[0]
        int CRD_LENGTH = crd.shape[0]
        double[:] height = np.zeros((CV_LENGTH, ))
        double[:, :] dis = np.zeros((CRD_LENGTH, CV_LENGTH))
        double[:] res = np.zeros((CRD_LENGTH, ))
    # Bounds checks are disabled and raw pointers are cast to CRD*/PATH*, so
    # every shape assumption is verified up front.
    if CV_LENGTH == 0:
        raise ValueError('cv must contain at least one point')
    if cv.shape[1] != 3:
        raise ValueError('cv must have shape (N, 3)')
    if CRD_LENGTH > 0 and crd.shape[1] < 3:
        raise ValueError('crd must provide 3 coordinates per row')
    if bw.shape[0] < 3:
        raise ValueError('bw must contain 3 bandwidths')
    if bias.shape[0] < CV_LENGTH:
        raise ValueError('bias must provide one value per cv point')

    GaussGetDist = <GaussDistFunc>dlsym(handle, 'GaussGetDistHeight')
    # dlsym returns NULL for an unknown symbol; calling it would crash.
    if GaussGetDist == NULL:
        raise RuntimeError("Symbol 'GaussGetDistHeight' not found in libcufes")

    # Kernel height of each cv point: its weight divided by the kernel volume.
    volume = SQRT2PI3 * bw[0] * bw[1] * bw[2]
    weight = get_weight(bias)
    for i in range(CV_LENGTH):
        height[i] = weight[i] / volume

    for i in range(CRD_LENGTH):
        # The returned status is ignored; its convention is not visible here.
        GaussGetDist(CV_LENGTH, <CRD*>&crd[i, 0], <PATH*>&cv[0, 0], &dis[i, 0], &height[0])

    # Reduce each row in C instead of converting it to a Python object.
    for i in range(CRD_LENGTH):
        total = 0.0
        for j in range(CV_LENGTH):
            total += dis[i, j]
        res[i] = -kT * log(total / CV_LENGTH)

    # Shift so the minimum is zero; np.min would raise on an empty array.
    if CRD_LENGTH > 0:
        res_min = np.min(res)
        for i in range(CRD_LENGTH):
            res[i] -= res_min
    return res

@cython.boundscheck(False)
@cython.wraparound(False)
cpdef double[:] FastPathFES(double[:, :] crd, double[:, :] cv, double[:] bw, double[:] bias):
    """Free-energy estimate using ``GaussGetDistHeightStream`` on two streams.

    ``cv`` is handed to ``StickCv`` once (device 0) and released with
    ``ReleaseCv`` at the end. Coordinates with an even index are submitted on
    the first stream, odd ones on the second. Each result is
    ``-kT * log(mean(dis))``; the output is shifted so its minimum is zero.

    Parameters
    ----------
    crd : double[:, :]
        Evaluation coordinates, at least 3 doubles per row.
    cv : double[:, :]
        Points of shape (N, 3), reinterpreted in place as N PATH structs.
    bw : double[:]
        Three bandwidths; their product scales the kernel volume.
    bias : double[:]
        Bias per cv point (at least N values), converted by ``get_weight``.
    All arrays must be C-contiguous because raw pointers are passed on.

    Returns
    -------
    double[:]
        Newly allocated array with one value per ``crd`` row.

    Raises
    ------
    ValueError
        If ``cv`` is empty or not (N, 3), ``crd`` rows are too short,
        ``bw`` has fewer than 3 values or ``bias`` has fewer than N values.
    RuntimeError
        If a required symbol cannot be resolved in the loaded library.
    """
    cdef:
        double volume, res_min, total
        double[:] weight
        StreamGaussDistFunc GaussGetDist = <StreamGaussDistFunc>dlsym(handle, 'GaussGetDistHeightStream')
        StickCv to_cuda = <StickCv>dlsym(handle, 'StickCv')
        ReleaseCv free_cuda = <ReleaseCv>dlsym(handle, 'ReleaseCv')
        StreamCreate cudaStreamCreate = <StreamCreate>dlsym(handle, 'cudaStreamCreate')
        StreamDestroy cudaStreamDestroy = <StreamDestroy>dlsym(handle, 'cudaStreamDestroy')
        PATH* cv_device
        int i, j
        int device_id = 0
        int CV_LENGTH = cv.shape[0]
        int CRD_LENGTH = crd.shape[0]
        int StreamSize = 2
        double[:] height = np.zeros((CV_LENGTH, ))
        double[:] dis = np.zeros((CV_LENGTH, ))
        double[:] res = np.zeros((CRD_LENGTH, ))
        # The handles are overwritten by cudaStreamCreate below, so they start
        # as NULL instead of pointing into an unrelated scratch array.
        cudaStream_t Stream1 = NULL
        cudaStream_t Stream2 = NULL

    # dlsym returns NULL for an unknown symbol; calling it would crash.
    if (GaussGetDist == NULL or to_cuda == NULL or free_cuda == NULL
            or cudaStreamCreate == NULL or cudaStreamDestroy == NULL):
        raise RuntimeError('libcufes is missing a symbol required by FastPathFES')
    # Bounds checks are disabled and raw pointers are cast to CRD*/PATH*, so
    # every shape assumption is verified before touching the device.
    if CV_LENGTH == 0:
        raise ValueError('cv must contain at least one point')
    if cv.shape[1] != 3:
        raise ValueError('cv must have shape (N, 3)')
    if CRD_LENGTH > 0 and crd.shape[1] < 3:
        raise ValueError('crd must provide 3 coordinates per row')
    if bw.shape[0] < 3:
        raise ValueError('bw must contain 3 bandwidths')
    if bias.shape[0] < CV_LENGTH:
        raise ValueError('bias must provide one value per cv point')

    # Kernel height of each cv point: its weight divided by the kernel volume.
    volume = SQRT2PI3 * bw[0] * bw[1] * bw[2]
    weight = get_weight(bias)
    for i in range(CV_LENGTH):
        height[i] = weight[i] / volume
    cv_device = to_cuda(CV_LENGTH, <PATH*>&cv[0, 0], device_id)

    cudaStreamCreate(&Stream1)
    cudaStreamCreate(&Stream2)

    # From here on only C-level code runs, so no exception can skip the
    # stream and device clean-up below.
    for i in range(CRD_LENGTH):
        # NOTE(review): both streams write into the same host buffer `dis`,
        # which is read right after the call without a synchronisation in
        # this function; this assumes the library call has completed its
        # output on return -- confirm.
        if i % StreamSize == 0:
            GaussGetDist(CV_LENGTH, <CRD*>&crd[i, 0], cv_device, &dis[0], &height[0], Stream1)
        else:
            GaussGetDist(CV_LENGTH, <CRD*>&crd[i, 0], cv_device, &dis[0], &height[0], Stream2)
        total = 0.0
        for j in range(CV_LENGTH):
            total += dis[j]
        res[i] = -kT * log(total / CV_LENGTH)

    # Shift so the minimum is zero; np.min would raise on an empty array.
    if CRD_LENGTH > 0:
        res_min = np.min(res)
        for i in range(CRD_LENGTH):
            res[i] -= res_min

    cudaStreamDestroy(Stream1)
    cudaStreamDestroy(Stream2)

    free_cuda(cv_device, device_id)

    return res

@cython.boundscheck(False)
@cython.wraparound(False)
cpdef double[:] StreamPathFES(double[:, :] crd, double[:, :] cv, double[:] bw, double[:] bias):
    """Two-stream free-energy estimate with explicit stream synchronisation.

    Same computation as ``FastPathFES``, but both streams are synchronised
    with ``cudaStreamSynchronize`` after the submission loop and before they
    are destroyed. ``cv`` is handed to ``StickCv`` once (device 0) and
    released with ``ReleaseCv`` at the end.

    Parameters
    ----------
    crd : double[:, :]
        Evaluation coordinates, at least 3 doubles per row.
    cv : double[:, :]
        Points of shape (N, 3), reinterpreted in place as N PATH structs.
    bw : double[:]
        Three bandwidths; their product scales the kernel volume.
    bias : double[:]
        Bias per cv point (at least N values), converted by ``get_weight``.
    All arrays must be C-contiguous because raw pointers are passed on.

    Returns
    -------
    double[:]
        Newly allocated array with one value per ``crd`` row, shifted so its
        minimum is zero.

    Raises
    ------
    ValueError
        If ``cv`` is empty or not (N, 3), ``crd`` rows are too short,
        ``bw`` has fewer than 3 values or ``bias`` has fewer than N values.
    RuntimeError
        If a required symbol cannot be resolved in the loaded library.
    """
    cdef:
        double volume, res_min, total
        double[:] weight
        StreamGaussDistFunc GaussGetDist = <StreamGaussDistFunc>dlsym(handle, 'GaussGetDistHeightStream')
        StickCv to_cuda = <StickCv>dlsym(handle, 'StickCv')
        ReleaseCv free_cuda = <ReleaseCv>dlsym(handle, 'ReleaseCv')
        StreamCreate cudaStreamCreate = <StreamCreate>dlsym(handle, 'cudaStreamCreate')
        StreamDestroy cudaStreamDestroy = <StreamDestroy>dlsym(handle, 'cudaStreamDestroy')
        StreamSynchronize cudaStreamSynchronize = <StreamSynchronize>dlsym(handle, 'cudaStreamSynchronize')
        PATH* cv_device
        int i, j
        int device_id = 0
        int CV_LENGTH = cv.shape[0]
        int CRD_LENGTH = crd.shape[0]
        int StreamSize = 2
        double[:] height = np.zeros((CV_LENGTH, ))
        double[:] dis = np.zeros((CV_LENGTH, ))
        double[:] res = np.zeros((CRD_LENGTH, ))
        # The handles are overwritten by cudaStreamCreate below, so they start
        # as NULL instead of pointing into an unrelated scratch array.
        cudaStream_t Stream1 = NULL
        cudaStream_t Stream2 = NULL

    # dlsym returns NULL for an unknown symbol; calling it would crash.
    if (GaussGetDist == NULL or to_cuda == NULL or free_cuda == NULL
            or cudaStreamCreate == NULL or cudaStreamDestroy == NULL
            or cudaStreamSynchronize == NULL):
        raise RuntimeError('libcufes is missing a symbol required by StreamPathFES')
    # Bounds checks are disabled and raw pointers are cast to CRD*/PATH*, so
    # every shape assumption is verified before touching the device.
    if CV_LENGTH == 0:
        raise ValueError('cv must contain at least one point')
    if cv.shape[1] != 3:
        raise ValueError('cv must have shape (N, 3)')
    if CRD_LENGTH > 0 and crd.shape[1] < 3:
        raise ValueError('crd must provide 3 coordinates per row')
    if bw.shape[0] < 3:
        raise ValueError('bw must contain 3 bandwidths')
    if bias.shape[0] < CV_LENGTH:
        raise ValueError('bias must provide one value per cv point')

    # Kernel height of each cv point: its weight divided by the kernel volume.
    volume = SQRT2PI3 * bw[0] * bw[1] * bw[2]
    weight = get_weight(bias)
    for i in range(CV_LENGTH):
        height[i] = weight[i] / volume
    cv_device = to_cuda(CV_LENGTH, <PATH*>&cv[0, 0], device_id)

    cudaStreamCreate(&Stream1)
    cudaStreamCreate(&Stream2)

    # From here on only C-level code runs, so no exception can skip the
    # stream and device clean-up below.
    for i in range(CRD_LENGTH):
        # NOTE(review): both streams write into the same host buffer `dis`,
        # which is read here before the streams are synchronised below; this
        # assumes the library call has completed its output on return --
        # confirm.
        if i % StreamSize == 0:
            GaussGetDist(CV_LENGTH, <CRD*>&crd[i, 0], cv_device, &dis[0], &height[0], Stream1)
        else:
            GaussGetDist(CV_LENGTH, <CRD*>&crd[i, 0], cv_device, &dis[0], &height[0], Stream2)
        total = 0.0
        for j in range(CV_LENGTH):
            total += dis[j]
        res[i] = -kT * log(total / CV_LENGTH)

    cudaStreamSynchronize(Stream1)
    cudaStreamSynchronize(Stream2)

    # Shift so the minimum is zero; np.min would raise on an empty array.
    if CRD_LENGTH > 0:
        res_min = np.min(res)
        for i in range(CRD_LENGTH):
            res[i] -= res_min

    cudaStreamDestroy(Stream1)
    cudaStreamDestroy(Stream2)

    free_cuda(cv_device, device_id)

    return res

@cython.boundscheck(False)
@cython.wraparound(False)
cpdef int SingleDevicePathFES(double[:, :] crd, double[:, :] cv, double[:] bw, double[:] bias, int device_id, double[:] res):
    """Two-stream free-energy estimate on a chosen device, written into ``res``.

    ``cv`` is handed to ``StickCv`` for ``device_id`` and released with
    ``ReleaseCv`` at the end. Coordinates with an even index are submitted to
    ``GaussHeightDevice`` on the first stream, odd ones on the second. Each
    result is ``-kT * log(mean(dis))``; ``res`` is then shifted in place so
    its minimum is zero.

    Parameters
    ----------
    crd : double[:, :]
        Evaluation coordinates, at least 3 doubles per row.
    cv : double[:, :]
        Points of shape (N, 3), reinterpreted in place as N PATH structs.
    bw : double[:]
        Three bandwidths; their product scales the kernel volume.
    bias : double[:]
        Bias per cv point, at most N values; converted by
        ``get_weight_nogil`` into an N-element weight buffer.
    device_id : int
        Device identifier forwarded to the library calls.
    res : double[:]
        Output buffer with at least one slot per ``crd`` row.
    All arrays must be C-contiguous because raw pointers are passed on.

    Returns
    -------
    int
        The status value returned by ``ReleaseCv``.

    Raises
    ------
    ValueError
        If ``cv`` is empty or not (N, 3), ``crd`` rows are too short, ``bw``
        has fewer than 3 values, ``bias`` is longer than ``cv`` or ``res`` is
        shorter than ``crd``.
    RuntimeError
        If a required symbol cannot be resolved in the loaded library.
    """
    cdef:
        double volume, res_min, total
        DeviceGaussFunc GaussGetDist = <DeviceGaussFunc>dlsym(handle, 'GaussHeightDevice')
        StickCv to_cuda = <StickCv>dlsym(handle, 'StickCv')
        ReleaseCv free_cuda = <ReleaseCv>dlsym(handle, 'ReleaseCv')
        StreamCreate cudaStreamCreate = <StreamCreate>dlsym(handle, 'cudaStreamCreate')
        StreamDestroy cudaStreamDestroy = <StreamDestroy>dlsym(handle, 'cudaStreamDestroy')
        StreamSynchronize cudaStreamSynchronize = <StreamSynchronize>dlsym(handle, 'cudaStreamSynchronize')
        PATH* cv_device
        int i, j
        int CV_LENGTH = cv.shape[0]
        int CRD_LENGTH = crd.shape[0]
        int StreamSize = 2
        double[:] height = np.zeros((CV_LENGTH, ))
        double[:] weight = np.zeros((CV_LENGTH,))
        double[:] dis = np.zeros((CV_LENGTH, ))
        # The handles are overwritten by cudaStreamCreate below, so they start
        # as NULL instead of pointing into an unrelated scratch array.
        cudaStream_t Stream1 = NULL
        cudaStream_t Stream2 = NULL

    # dlsym returns NULL for an unknown symbol; calling it would crash.
    if (GaussGetDist == NULL or to_cuda == NULL or free_cuda == NULL
            or cudaStreamCreate == NULL or cudaStreamDestroy == NULL
            or cudaStreamSynchronize == NULL):
        raise RuntimeError('libcufes is missing a symbol required by SingleDevicePathFES')
    # Bounds checks are disabled and raw pointers are cast to CRD*/PATH*, so
    # every shape assumption is verified before touching the device.
    if CV_LENGTH == 0:
        raise ValueError('cv must contain at least one point')
    if cv.shape[1] != 3:
        raise ValueError('cv must have shape (N, 3)')
    if CRD_LENGTH > 0 and crd.shape[1] < 3:
        raise ValueError('crd must provide 3 coordinates per row')
    if bw.shape[0] < 3:
        raise ValueError('bw must contain 3 bandwidths')
    # The library writes one weight per bias value into the N-element buffer.
    if bias.shape[0] > CV_LENGTH:
        raise ValueError('bias must not be longer than cv')
    if res.shape[0] < CRD_LENGTH:
        raise ValueError('res must provide one slot per crd row')

    # Kernel height of each cv point: its weight divided by the kernel volume.
    volume = SQRT2PI3 * bw[0] * bw[1] * bw[2]
    get_weight_nogil(bias, weight)
    for i in range(CV_LENGTH):
        height[i] = weight[i] / volume

    cv_device = to_cuda(CV_LENGTH, <PATH*>&cv[0, 0], device_id)

    cudaStreamCreate(&Stream1)
    cudaStreamCreate(&Stream2)
    # From here on only C-level code runs, so no exception can skip the
    # stream and device clean-up below.
    for i in range(CRD_LENGTH):
        # NOTE(review): both streams write into the same host buffer `dis`,
        # which is read here before the streams are synchronised below; this
        # assumes the library call has completed its output on return --
        # confirm.
        if i % StreamSize == 0:
            GaussGetDist(CV_LENGTH, <CRD*>&crd[i, 0], cv_device, &dis[0], &height[0], device_id, Stream1)
        else:
            GaussGetDist(CV_LENGTH, <CRD*>&crd[i, 0], cv_device, &dis[0], &height[0], device_id, Stream2)
        total = 0.0
        for j in range(CV_LENGTH):
            total += dis[j]
        res[i] = -kT * log(total / CV_LENGTH)
    cudaStreamSynchronize(Stream1)
    cudaStreamSynchronize(Stream2)

    # Shift element by element: an augmented assignment on the memoryview
    # would rebind the local name to a new array and leave the caller's
    # buffer unshifted.
    if CRD_LENGTH > 0:
        res_min = np.min(res[:CRD_LENGTH])
        for i in range(CRD_LENGTH):
            res[i] -= res_min

    cudaStreamDestroy(Stream1)
    cudaStreamDestroy(Stream2)

    return free_cuda(cv_device, device_id)

@cython.boundscheck(False)
@cython.wraparound(False)
cpdef double SingleDevicePathFESNogil(double[:, :] crd, double[:, :] cv, double[:] bw, double[:] bias, int device_id, double[:] res,
                                      double[:] height, double[:,:] dis) nogil:
    """GIL-free free-energy kernel for one device; returns the minimum of ``res``.

    Unlike ``SingleDevicePathFES`` this variant allocates nothing: the kernel
    heights and the scratch buffer are supplied by the caller, and the
    results are NOT shifted here. The smallest value written to ``res`` is
    returned instead.

    Parameters
    ----------
    crd : rows whose first element address is cast to CRD* (3 doubles).
    cv : array whose first element address is cast to PATH* and handed to
        ``StickCv`` for ``device_id``.
    bw, bias : accepted but not used in this function body.
    device_id : forwarded to StickCv, GaussHeightDevice and ReleaseCv.
    res : output, one value per ``crd`` row: ``-kT * log(sum(dis_row) / N)``.
    height : per-cv-point heights passed through to the library.
    dis : scratch buffer; row 0 is used with the first stream and row 1 with
        the second, so it needs at least 2 rows of N doubles.

    Returns
    -------
    The minimum of ``res``, or the initial sentinel 999999 if no entry of
    ``res`` is smaller than it (for example when ``crd`` is empty).
    """
    cdef:
        # volume, res_min and min_res_1 are declared but never used below.
        double volume, res_min
        # NOTE(review): none of the dlsym results is checked for NULL.
        DeviceGaussFunc GaussGetDist = <DeviceGaussFunc>dlsym(handle, 'GaussHeightDevice')
        StickCv to_cuda = <StickCv>dlsym(handle, 'StickCv')
        ReleaseCv free_cuda = <ReleaseCv>dlsym(handle, 'ReleaseCv')
        StreamCreate cudaStreamCreate = <StreamCreate>dlsym(handle, 'cudaStreamCreate')
        StreamDestroy cudaStreamDestroy = <StreamDestroy>dlsym(handle, 'cudaStreamDestroy')
        StreamSynchronize cudaStreamSynchronize = <StreamSynchronize>dlsym(handle, 'cudaStreamSynchronize')
        PATH* cv_device
        int success, i, j
        int CV_LENGTH = cv.shape[0]
        int CRD_LENGTH = crd.shape[0]
        int StreamSize = 2
        int _stream1 = 0
        int _stream2 = 1
        double total_dis_0 = 0
        double total_dis_1 = 0
        double min_res = 999999
        double min_res_1 = 999999
        # Placeholder values only: both handles are passed by address to
        # cudaStreamCreate below.
        cudaStream_t Stream1 = <CUstream_st*>&_stream1
        cudaStream_t Stream2 = <CUstream_st*>&_stream2
    
    cv_device = to_cuda(CV_LENGTH, <PATH*>&cv[0][0], device_id)

    cudaStreamCreate(&Stream1)
    cudaStreamCreate(&Stream2)

    # Iterations are spread over StreamSize threads. The parity of i selects
    # the stream, the scratch row of `dis` and the accumulator.
    # NOTE(review): which iterations a thread receives depends on the prange
    # schedule, so two threads may end up using the same stream / scratch row
    # at the same time -- confirm this cannot happen.
    # NOTE(review): each scratch row is read right after the library call and
    # before the streams are synchronised further down; this assumes the call
    # has completed its output on return -- confirm.
    for i in prange(CRD_LENGTH, num_threads=StreamSize):
        if i % StreamSize == 0:
            success = GaussGetDist(CV_LENGTH, <CRD*>&crd[i][0], cv_device, &dis[0][0], &height[0], device_id, Stream1)
            total_dis_0 = 0
            # Sum the scratch row filled for this coordinate.
            for j in prange(CV_LENGTH):
                total_dis_0 += dis[0][j]
            res[i] = -kT * log(total_dis_0 / CV_LENGTH)
        else:
            success = GaussGetDist(CV_LENGTH, <CRD*>&crd[i][0], cv_device, &dis[1][0], &height[0], device_id, Stream2)
            total_dis_1 = 0
            # Sum the scratch row filled for this coordinate.
            for j in prange(CV_LENGTH):
                total_dis_1 += dis[1][j]
            res[i] = -kT * log(total_dis_1 / CV_LENGTH)
    # Sequential scan for the smallest result; the caller applies the shift.
    for i in range(CRD_LENGTH):
        if res[i] < min_res:
            min_res = res[i]
    
    cudaStreamSynchronize(Stream1)
    cudaStreamSynchronize(Stream2)
    cudaStreamDestroy(Stream1)
    cudaStreamDestroy(Stream2)
    # The status returned by ReleaseCv is stored but not returned.
    success = free_cuda(cv_device, device_id)
    return min_res

@cython.boundscheck(False)
@cython.wraparound(False)
cpdef double[:] DevicePathFES(double[:, :] crd, double[:, :] cv, double[:] bw, double[:] bias, int[:] device_ids):
    """Multi-device free-energy estimate: one OpenMP thread per device.

    ``crd`` is split into consecutive segments of
    ``len(crd) // len(device_ids) + 1`` rows. Segment ``i`` is processed by
    ``SingleDevicePathFESNogil`` on ``device_ids[i]`` (trailing segments may
    be empty). Each call returns the minimum of its segment; the global
    minimum is then subtracted from every result.

    Parameters
    ----------
    crd : double[:, :]
        Evaluation coordinates, at least 3 doubles per row.
    cv : double[:, :]
        Points of shape (N, 3), reinterpreted as N PATH structs downstream.
    bw : double[:]
        Three bandwidths; their product scales the kernel volume.
    bias : double[:]
        Bias per cv point, at most N values; converted by
        ``get_weight_nogil`` into an N-element weight buffer.
    device_ids : int[:]
        Device identifiers, one worker thread each.
    All arrays must be C-contiguous because raw pointers are passed on.

    Returns
    -------
    double[:]
        Newly allocated array with one value per ``crd`` row, shifted by the
        smallest per-device minimum.

    Raises
    ------
    ValueError
        If ``cv`` is not (N, 3), ``crd`` rows are too short, ``bw`` has fewer
        than 3 values or ``bias`` is longer than ``cv``.
    ZeroDivisionError
        If ``device_ids`` is empty.
    """
    cdef:
        int i
        # Declared as a C double so the height loop below does not box every
        # division into Python objects.
        double volume
        double res_min
        int StreamSize = 2
        int device_nums = device_ids.shape[0]
        int CRD_LENGTH = crd.shape[0]
        double[:] res = np.zeros((CRD_LENGTH, ))
        int segment_size = CRD_LENGTH // device_nums + 1
        int last_idx
        int CV_LENGTH = cv.shape[0]
        double[:] height = np.zeros((CV_LENGTH, ))
        double[:] weight = np.zeros((CV_LENGTH, ))
        # One (StreamSize, N) scratch block per device.
        double[:,:,:] dis = np.zeros((device_nums, StreamSize, CV_LENGTH))
        double[:] min_res = np.zeros((device_nums, ))

    # Bounds checks are disabled and raw pointers are cast to CRD*/PATH*
    # downstream, so the shape assumptions are verified up front.
    if CV_LENGTH > 0 and cv.shape[1] != 3:
        raise ValueError('cv must have shape (N, 3)')
    if CRD_LENGTH > 0 and crd.shape[1] < 3:
        raise ValueError('crd must provide 3 coordinates per row')
    if bw.shape[0] < 3:
        raise ValueError('bw must contain 3 bandwidths')
    # The library writes one weight per bias value into the N-element buffer.
    if bias.shape[0] > CV_LENGTH:
        raise ValueError('bias must not be longer than cv')

    # Kernel height of each cv point: its weight divided by the kernel volume.
    volume = SQRT2PI3 * bw[0] * bw[1] * bw[2]
    get_weight_nogil(bias, weight)
    for i in range(CV_LENGTH):
        height[i] = weight[i] / volume

    for i in prange(device_nums, nogil=True, num_threads=device_nums, schedule='static'):
        # Clamp the end of the last segment to the number of coordinates.
        if (i+1)*segment_size > CRD_LENGTH:
            last_idx = CRD_LENGTH
        else:
            last_idx = (i+1)*segment_size
        min_res[i] = SingleDevicePathFESNogil(crd[i*segment_size:last_idx], cv, bw, bias, device_ids[i], res[i*segment_size:last_idx],
                                              height, dis[i])

    # Global minimum over the per-device minima, then shift every result.
    res_min = min_res[0]
    for i in range(1, device_nums):
        if min_res[i] < res_min:
            res_min = min_res[i]
    for i in range(CRD_LENGTH):
        res[i] -= res_min
    return res

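# NOTE: the dlclose(handle) call below is left disabled, so the library handle
# opened above is not released by this part of the file.
# while not True:
#     dlclose(handle)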
